{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "31d093db",
   "metadata": {},
   "source": [
    "\n",
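    "The following code demonstrates the basic usage of BERT.\n",
    "\n",
    "The cell below shows how the model predicts the masked word for the input “The capital of China is \\[MASK\\].” and prints the five most probable candidates."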
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "026be16d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The capital of China is beijing.\n",
      "The capital of China is nanjing.\n",
      "The capital of China is shanghai.\n",
      "The capital of China is guangzhou.\n",
      "The capital of China is shenzhen.\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "from transformers import BertTokenizer, BertForMaskedLM\n",
    "from torch.nn import functional as F\n",
    "import torch\n",
    "\n",
    "# Use the bert-base-uncased model for prediction, together with its matching tokenizer\n",
    "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
    "model = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True)\n",
    "\n",
    "# Prepare the input sentence 'The capital of China is [MASK].'\n",
    "text = 'The capital of China is ' + tokenizer.mask_token + '.'\n",
    "# Encode the input sentence as PyTorch tensors\n",
    "inputs = tokenizer(text, return_tensors='pt')\n",
    "# Locate the position of [MASK] (torch.where returns a tuple; take its first element)\n",
    "mask_index = torch.where(inputs['input_ids'][0] == tokenizer.mask_token_id)[0]\n",
    "output = model(**inputs)\n",
    "logits = output.logits\n",
    "# From the output distribution at the [MASK] position, pick the 5 most probable tokens and print them\n",
    "distribution = F.softmax(logits, dim=-1)\n",
    "mask_word = distribution[0, mask_index, :]\n",
    "top_5 = torch.topk(mask_word, 5, dim=1)[1][0]\n",
    "for token in top_5:\n",
    "    word = tokenizer.decode([token])\n",
    "    new_sentence = text.replace(tokenizer.mask_token, word)\n",
    "    print(new_sentence)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a03f59ac",
   "metadata": {},
   "source": [
    "\n",
    "The following shows how to fine-tune BERT for text classification, using the Books dataset from Chapter 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "895c0795",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train size = 8627 , test size = 2157\n",
      "{0: '计算机类', 1: '艺术传媒类', 2: '经管类'}\n",
      "8627 2157\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 100/100 [00:00<00:00, 8225.09it/s]\n",
      "100%|██████████| 100/100 [00:00<00:00, 4294.85it/s]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "import sys\n",
    "from tqdm import tqdm\n",
    "\n",
    "# Import the Books dataset implemented earlier\n",
    "sys.path.append('./code')\n",
    "from utils import BooksDataset\n",
    "\n",
    "dataset = BooksDataset()\n",
    "# Print the mapping from label IDs to class names\n",
    "print(dataset.id2label)\n",
    "print(len(dataset.train_data), len(dataset.test_data))\n",
    "\n",
    "# Next, tokenize the text and take 100 examples each for training and testing.\n",
    "# Only 100 examples are used so that the code runs on a CPU in reasonable time.\n",
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')\n",
    "\n",
    "def tokenize_function(text):\n",
    "    return tokenizer(text, padding='max_length', truncation=True)\n",
    "\n",
    "def tokenize(raw_data):\n",
    "    dataset = []\n",
    "    for data in tqdm(raw_data):\n",
    "        tokens = tokenize_function(data['en_book'])\n",
    "        tokens['label'] = data['label']\n",
    "        dataset.append(tokens)\n",
    "    return dataset\n",
    "        \n",
    "small_train_dataset = tokenize(dataset.train_data[:100])\n",
    "small_eval_dataset = tokenize(dataset.test_data[:100])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39353b6d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Load the pretrained bert-base-cased model with a sequence-classification output head;\n",
    "# the number of class labels is 3\n",
    "from transformers import AutoModelForSequenceClassification\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\\\n",
    "    'bert-base-cased', num_labels=len(dataset.id2label))\n",
    "\n",
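    "# To monitor model performance during training, define an evaluation function that computes classification accuracy\n",
    "import numpy as np\n",
    "# The evaluate library can be installed with the following command\n",
    "# pip install evaluate\n",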
    "import evaluate\n",
    "\n",
    "metric = evaluate.load('accuracy')\n",
    "\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    predictions = np.argmax(logits, axis=-1)\n",
    "    return metric.compute(predictions=predictions, references=labels)\n",
    "\n",
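    "# Build the training configuration with the TrainingArguments class\n",
    "# evaluation_strategy='epoch' computes the evaluation metrics at the end of every epoch\n",
    "# (recent transformers releases rename this argument to eval_strategy)\n",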
    "from transformers import TrainingArguments, Trainer\n",
    "training_args = TrainingArguments(output_dir='test_trainer',\\\n",
    "    evaluation_strategy='epoch')\n",
    "\n",
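    "# The Trainer class that ships with transformers wraps many details of model training,\n",
    "# such as data conversion, evaluation, and checkpoint saving\n",
    "# Calling Trainer is a convenient way to run the standard fine-tuning procedure; it trains for 3 epochs by default\n",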
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=small_train_dataset,\n",
    "    eval_dataset=small_eval_dataset,\n",
    "    compute_metrics=compute_metrics,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "337b979a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='39' max='39' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [39/39 12:20, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.962486</td>\n",
       "      <td>0.520000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.852982</td>\n",
       "      <td>0.670000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.816384</td>\n",
       "      <td>0.680000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
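    "# By default, the fine-tuning procedure logs training to wandb; see the wandb website for usage\n",
    "# Here wandb is disabled via the WANDB_DISABLED environment variable to avoid unnecessary network access\n",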
    "import os\n",
    "os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb019379",
   "metadata": {},
   "source": [
    "The code above implements a simple fine-tuning procedure by calling the Trainer class. Next, we show how to write a custom fine-tuning loop."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7c8f4428",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f7c36a2ba7e40bfb645c9f49e543810",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "del model\n",
    "del trainer\n",
    "# If you are using a GPU, clear the GPU cache\n",
    "torch.cuda.empty_cache()\n",
    "\n",
    "# Use the DataLoader class to feed data to the model\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Convert Python lists into PyTorch tensors\n",
    "def collate(batch):\n",
    "    input_ids, token_type_ids, attention_mask, labels = [], [], [], []\n",
    "    for d in batch:\n",
    "        input_ids.append(d['input_ids'])\n",
    "        token_type_ids.append(d['token_type_ids'])\n",
    "        attention_mask.append(d['attention_mask'])\n",
    "        labels.append(d['label'])\n",
    "    input_ids = torch.tensor(input_ids)\n",
    "    token_type_ids = torch.tensor(token_type_ids)\n",
    "    attention_mask = torch.tensor(attention_mask)\n",
    "    labels = torch.tensor(labels)\n",
    "    return {'input_ids': input_ids, 'token_type_ids': token_type_ids,\\\n",
    "        'attention_mask': attention_mask, 'labels': labels}\n",
    "\n",
    "train_dataloader = DataLoader(small_train_dataset, shuffle=True,\\\n",
    "    batch_size=8, collate_fn=collate)\n",
    "eval_dataloader = DataLoader(small_eval_dataset, batch_size=8,\\\n",
    "    collate_fn=collate)\n",
    "\n",
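    "# Load the model, then prepare the optimizer (which updates the parameters) and the scheduler\n",
    "# (which adjusts the learning rate during training for better fine-tuning results)\n",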
    "from transformers import AutoModelForSequenceClassification\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\\\n",
    "    \"bert-base-cased\", num_labels=len(dataset.id2label))\n",
    "\n",
    "from torch.optim import AdamW\n",
    "optimizer = AdamW(model.parameters(), lr=5e-5)\n",
    "\n",
    "from transformers import get_scheduler\n",
    "num_epochs = 3\n",
    "num_training_steps = num_epochs * len(train_dataloader)\n",
    "lr_scheduler = get_scheduler(\n",
    "    name=\"linear\", optimizer=optimizer, num_warmup_steps=0,\\\n",
    "    num_training_steps=num_training_steps\n",
    ")\n",
    "\n",
    "# Detect whether a GPU is available; if so, move the model to GPU memory\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available()\\\n",
    "    else torch.device(\"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# Training loop\n",
    "from tqdm.auto import tqdm\n",
    "progress_bar = tqdm(range(num_training_steps))\n",
    "\n",
    "for epoch in range(num_epochs):\n",
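    "    # At the start of each epoch, put the model in training mode (model.training = True);\n",
    "    # this flag affects layers such as dropout (dropout is enabled during training)\n",
    "    model.train()\n",
    "    for batch in train_dataloader:\n",
    "        # If a GPU is available, this moves the batch to GPU memory\n",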
    "        batch = {k: v.to(device) for k, v in batch.items()}\n",
    "        outputs = model(**batch)\n",
    "        loss = outputs.loss\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        lr_scheduler.step()\n",
    "        # After the parameter update, clear the gradients so they do not accumulate into the next step\n",
    "        optimizer.zero_grad()\n",
    "        progress_bar.update(1)\n",
    "progress_bar.close()\n",
    "import evaluate\n",
    "\n",
    "# After training, evaluate on the test set to obtain the model's accuracy\n",
    "model.eval()\n",
    "metric = evaluate.load(\"accuracy\")\n",
    "for batch in eval_dataloader:\n",
    "    batch = {k: v.to(device) for k, v in batch.items()}\n",
    "    with torch.no_grad():\n",
    "        outputs = model(**batch)\n",
    "\n",
    "    logits = outputs.logits\n",
    "    predictions = torch.argmax(logits, dim=-1)\n",
    "    metric.add_batch(predictions=predictions, references=batch[\"labels\"])\n",
    "acc = metric.compute()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f871e435",
   "metadata": {},
   "source": [
    "\n",
    "The following code demonstrates how to fine-tune GPT-2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d31da145",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "19206 4802\n",
      "['the', 'Ġlittle', 'Ġprince', 'Ġ', 'ĊĊ', 'Ċ', 'Ċ', 'anto', 'ine', 'Ġde']\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "import sys\n",
    "\n",
    "# Import the 'The Little Prince' dataset used in Chapter 3\n",
    "sys.path.append('../code')\n",
    "from utils import TheLittlePrinceDataset\n",
    "\n",
    "full_text = TheLittlePrinceDataset(tokenize=False).text\n",
    "# Next, load the GPT-2 tokenizer and tokenize the text.\n",
    "from transformers import AutoTokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
    "\n",
    "full_tokens = tokenizer.tokenize(full_text.lower())\n",
    "train_size = int(len(full_tokens) * 0.8)\n",
    "train_tokens = full_tokens[:train_size]\n",
    "test_tokens = full_tokens[train_size:]\n",
    "print(len(train_tokens), len(test_tokens))\n",
    "print(train_tokens[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b301b7e6",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.utils.data import TensorDataset\n",
    "\n",
    "# Split the token sequence into blocks of block_size tokens\n",
    "block_size = 128\n",
    "\n",
    "def split_blocks(tokens):\n",
    "    token_ids = []\n",
    "    # Integer division keeps only full blocks; a trailing partial block is dropped,\n",
    "    # so no padding is needed\n",
    "    for i in range(len(tokens) // block_size):\n",
    "        _tokens = tokens[i*block_size:(i+1)*block_size]\n",
    "        _token_ids = tokenizer.convert_tokens_to_ids(_tokens)\n",
    "        token_ids.append(_token_ids)\n",
    "    return token_ids\n",
    "\n",
    "train_dataset = split_blocks(train_tokens)\n",
    "test_dataset = split_blocks(test_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "94647ae7",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='57' max='57' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [57/57 04:51, Epoch 3/3]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.240260</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.152012</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>No log</td>\n",
       "      <td>3.125814</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='5' max='5' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [5/5 00:05]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Perplexity: 22.78\n"
     ]
    }
   ],
   "source": [
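    "# Create a data collator that converts the tokenized blocks into tensors the model can train on\n",
    "# Note that the fine-tuning task here is language modeling, not masked language modeling (mlm=False)\n",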
    "from transformers import DataCollatorForLanguageModeling\n",
    "\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "data_collator = DataCollatorForLanguageModeling(tokenizer=\\\n",
    "    tokenizer, mlm=False)\n",
    "\n",
    "# Load the model, set up the training arguments, and call the Trainer class to run training\n",
    "from transformers import AutoModelForCausalLM, TrainingArguments, Trainer\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=\"test_trainer\",\n",
    "    evaluation_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    weight_decay=0.01,\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=test_dataset,\n",
    "    data_collator=data_collator,\n",
    ")\n",
    "\n",
    "trainer.train()\n",
    "\n",
    "# Evaluate on the test set and compute the perplexity\n",
    "import math\n",
    "eval_results = trainer.evaluate()\n",
    "print(f\"Perplexity: {math.exp(eval_results['eval_loss']):.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dd0dbdb9",
   "metadata": {},
   "source": [
    "Here we use HuggingFace to show how to generate text with the GPT-2 model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "399168ba",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output:\n",
      "----------------------------------------------------------------------------------------------------\n",
      "I enjoy learning with this book. I have been reading it for a while now and I am very happy with it. I have been reading it for a while now and I am very happy with it.\n",
      "\n",
      "I have been reading it for a\n",
      "Output:\n",
      "----------------------------------------------------------------------------------------------------\n",
      "I enjoy learning with this book, and I hope you enjoy reading it as much as I do.\n",
      "\n",
      "I hope you enjoy reading this book, and I hope you enjoy reading it as much as I do.\n",
      "\n",
      "I hope you enjoy reading\n",
      "Output:\n",
      "----------------------------------------------------------------------------------------------------\n",
      "0: I enjoy learning with this book, and I hope you enjoy reading it as much as I do.\n",
      "\n",
      "If you have any questions or comments, feel free to leave them in the comments below.\n",
      "1: I enjoy learning with this book, and I hope you enjoy reading it as much as I do.\n",
      "\n",
      "If you have any questions or comments, please feel free to leave them in the comments below.\n",
      "2: I enjoy learning with this book, and I hope you enjoy reading it as much as I do.\n",
      "\n",
      "If you have any questions or comments, feel free to leave them in the comment section below.\n",
      "3: I enjoy learning with this book, and I hope you enjoy reading it as much as I do.\n",
      "\n",
      "If you have any questions or comments, feel free to leave them in the comments section below.\n",
      "4: I enjoy learning with this book, and I hope you enjoy reading it as much as I do.\n",
      "\n",
      "If you have any questions or comments, feel free to leave them in the comments below!\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "import torch\n",
    "from transformers import GPT2LMHeadModel, GPT2Tokenizer\n",
    "\n",
    "tokenizer = GPT2Tokenizer.from_pretrained('gpt2')\n",
    "model = GPT2LMHeadModel.from_pretrained('gpt2',\\\n",
    "    pad_token_id=tokenizer.eos_token_id)\n",
    "# Input text\n",
    "input_ids = tokenizer.encode('I enjoy learning with this book',\\\n",
    "    return_tensors='pt')\n",
    "\n",
    "# Generate the output text with greedy decoding\n",
    "greedy_output = model.generate(input_ids, max_length=50)\n",
    "print(\"Output:\\n\" + 100 * '-')\n",
    "print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))\n",
    "\n",
    "# Generate with beam search; stop searching once enough candidate sentences have been generated\n",
    "beam_output = model.generate(\n",
    "    input_ids, \n",
    "    max_length=50, \n",
    "    num_beams=5, \n",
    "    early_stopping=True\n",
    ")\n",
    "\n",
    "print(\"Output:\\n\" + 100 * '-')\n",
    "print(tokenizer.decode(beam_output[0], skip_special_tokens=True))\n",
    "\n",
    "# Return multiple sentences\n",
    "beam_outputs = model.generate(\n",
    "    input_ids, \n",
    "    max_length=50, \n",
    "    num_beams=5, \n",
    "    no_repeat_ngram_size=2, \n",
    "    num_return_sequences=5, \n",
    "    early_stopping=True\n",
    ")\n",
    "\n",
    "print(\"Output:\\n\" + 100 * '-')\n",
    "for i, beam_output in enumerate(beam_outputs):\n",
    "    print(\"{}: {}\".format(i, tokenizer.decode(beam_output,\\\n",
    "        skip_special_tokens=True)))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00d323e9",
   "metadata": {},
   "source": [
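    "HuggingFace integrates many pretrained language models. You can call a specific pretrained language model directly through its own interface, but this approach is relatively complex and requires some knowledge of the model and its API. Alternatively, you can use these models as black boxes through the pipeline module: pipeline automatically selects a suitable pretrained language model for the specified task, and you can also pass a parameter to choose a particular model. The code below demonstrates the pipeline module on several tasks; you can also visit the HuggingFace website to see which models are supported."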
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fcc3533",
   "metadata": {},
   "source": [
    "\n",
    "\n",
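    "The following uses sentiment classification as an example to demonstrate pretrained language models on text classification tasks."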
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "62235e0e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english...\n",
      "No model was supplied, defaulted to facebook/bart-large-mnli..."
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'label': 'POSITIVE', 'score': 0.9998708963394165}]\n",
      "[{'label': 'POSITIVE', 'score': 0.9998835325241089}, {'label': 'NEGATIVE', 'score': 0.9994825124740601}, {'label': 'POSITIVE', 'score': 0.9998630285263062}]\n",
      "[{'sequence': 'A helicopter is flying in the sky', 'labels': ['machine', 'animal'], 'scores': [0.9938627481460571, 0.006137245334684849]}, {'sequence': 'A bird is flying in the sky', 'labels': ['animal', 'machine'], 'scores': [0.9987970590591431, 0.001202935236506164]}]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "from transformers import pipeline\n",
    "\n",
    "clf = pipeline('sentiment-analysis')\n",
    "print(clf('Haha, today is a nice day!'))\n",
    "\n",
    "print(clf(['The food is amazing', 'The assignment is weigh too hard',\\\n",
    "           'NLP is so much fun']))\n",
    "\n",
    "clf = pipeline('zero-shot-classification')\n",
    "print(clf(sequences=['A helicopter is flying in the sky',\\\n",
    "                     'A bird is flying in the sky'],\n",
    "   candidate_labels=['animal', 'machine']))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3008cc7b",
   "metadata": {},
   "source": [
    "\n",
    "The following demonstrates pretrained language models on two text generation tasks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "71d72551",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to gpt2...\n",
      "No model was supplied, defaulted to distilroberta-base...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'generated_text': \"In this course, we will teach you how to get started with the code of a given app. This way, you will build new apps that fit your needs but are still well behaved and understandable (unless you're already using Swift to understand the language\"}]\n",
      "[{'score': 0.1961982101202011, 'token': 30412, 'token_str': ' mathematical', 'sequence': 'This course will teach you all about mathematical models.'}, {'score': 0.040527306497097015, 'token': 38163, 'token_str': ' computational', 'sequence': 'This course will teach you all about computational models.'}, {'score': 0.033017922192811966, 'token': 27930, 'token_str': ' predictive', 'sequence': 'This course will teach you all about predictive models.'}, {'score': 0.0319414846599102, 'token': 745, 'token_str': ' building', 'sequence': 'This course will teach you all about building models.'}, {'score': 0.024523010477423668, 'token': 3034, 'token_str': ' computer', 'sequence': 'This course will teach you all about computer models.'}]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "generator = pipeline('text-generation')\n",
    "print(generator('In this course, we will teach you how to'))\n",
    "\n",
    "unmasker = pipeline('fill-mask')\n",
    "print(unmasker('This course will teach you all about <mask> models.'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f61e334c",
   "metadata": {},
   "source": [
    "\n",
    "\n",
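    "Given the task “question-answering”, pipeline automatically returns the default pretrained question-answering model “distilbert-base-cased-distilled-squad”. Pass in a question and a context to get the answer."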
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3f943e42",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to distilbert-base-cased-distilled-squad...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'score': 0.7787413597106934, 'start': 34, 'end': 63, 'answer': 'Shanghai Jiao Tong University'}\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "question_answerer = pipeline('question-answering')\n",
    "print(question_answerer(question='Where do I graduate from?', \n",
    "    context=\"I received my bachlor's degree at Shanghai \"+\\\n",
    "        \"Jiao Tong University (SJTU).\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68098304",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "Given the task “summarization”, pipeline automatically returns the default pretrained language model “sshleifer/distilbart-cnn-12-6”. Pass in a passage of text to get its summary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1977a486",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "No model was supplied, defaulted to sshleifer/distilbart-cnn-12-6...\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'summary_text': \" The 2022 Winter Olympics was held in Beijing, China, and surrounding areas . It was the 24th edition of the Winter Olympic Games . The Games featured a record 109 events across 15 disciplines, with big air freestyle skiing and women's monobob making their Olympic debuts as medal events . Norway won 37 medals, of which 16 were gold, setting a new record for the largest number of gold medals won at a single Winter Olympics .\"}]\n"
     ]
    }
   ],
   "source": [
    "\"\"\"\n",
    "Code from the GitHub project huggingface/transformers\n",
    "(Copyright (c) 2020, The HuggingFace Team, Apache-2.0 License; see the appendix)\n",
    "\"\"\"\n",
    "summarizer = pipeline('summarization')\n",
    "print(summarizer(\n",
    "    \"\"\"\n",
    "    The 2022 Winter Olympics (2022年冬季奥林匹克运动会), officially \n",
    "    called the XXIV Olympic Winter Games (Chinese: 第二十四届冬季奥\n",
    "    林匹克运动会; pinyin: Dì Èrshísì Jiè Dōngjì Àolínpǐkè Yùndònghuì) \n",
    "    and commonly known as Beijing 2022 (北京2022), was an international \n",
    "    winter multi-sport event held from 4 to 20 February 2022 in Beijing, \n",
    "    China, and surrounding areas with competition in selected events \n",
    "    beginning 2 February 2022.[1] It was the 24th edition of the Winter \n",
    "    Olympic Games. Beijing was selected as host city in 2015 at the \n",
    "    128th IOC Session in Kuala Lumpur, Malaysia, marking its second \n",
    "    time hosting the Olympics, and the last of three consecutive \n",
    "    Olympics hosted in East Asia following the 2018 Winter Olympics \n",
    "    in Pyeongchang County, South Korea, and the 2020 Summer Olympics \n",
    "    in Tokyo, Japan. Having previously hosted the 2008 Summer Olympics, \n",
    "    Beijing became the first city to have hosted both the Summer and \n",
    "    Winter Olympics. The venues for the Games were concentrated around \n",
    "    Beijing, its suburb Yanqing District, and Zhangjiakou, with some \n",
    "    events (including the ceremonies and curling) repurposing venues \n",
    "    originally built for Beijing 2008 (such as Beijing National \n",
    "    Stadium and the Beijing National Aquatics Centre). The Games \n",
    "    featured a record 109 events across 15 disciplines, with big air \n",
    "    freestyle skiing and women's monobob making their Olympic debuts \n",
    "    as medal events, as well as several new mixed competitions. \n",
    "    A total of 2,871 athletes representing 91 teams competed in the \n",
    "    Games, with Haiti and Saudi Arabia making their Winter Olympic \n",
    "    debut. Norway finished at the top of the medal table \n",
    "    for the second successive Winter Olympics, winning a total of 37 \n",
    "    medals, of which 16 were gold, setting a new record for the \n",
    "    largest number of gold medals won at a single Winter Olympics. \n",
    "    The host nation China finished third with nine gold medals and \n",
    "    also eleventh place by total medals won, marking its most \n",
    "    successful performance in Winter Olympics history.[4]\n",
    "    \"\"\"\n",
    "))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
